module Vert

using Optim

"""
    Market{T}

Parameters of the vertically linked market. All fields share the numeric type `T`.
Fields appear below in positional-constructor order.

Cost coefficients are named `β<stage>_<order><firm>`:
- stage `g` is growing/cultivation and `p` is processing/manufacturing;
- order `1` is the linear term and `2` the quadratic term;
- the last digit is the firm (1 or 2).

Firm `i`'s marginal growing cost is therefore `βg_1i + 2*βg_2i*q_gi`.
"""
mutable struct Market{T}
    # ---- demand: consumers satisfy p_r*(1+τ_r) == (q/k)^(-1/r) ----
    k::T        # demand scale
    r::T        # price elasticity of demand
    # ---- costs ----
    βg_11::T    # growing, linear term, firm 1
    βg_21::T    # growing, quadratic term, firm 1
    βg_12::T    # growing, linear term, firm 2
    βg_22::T    # growing, quadratic term, firm 2
    βp_11::T    # processing, linear term, firm 1
    βp_21::T    # processing, quadratic term, firm 1
    βp_12::T    # processing, linear term, firm 2
    βp_22::T    # processing, quadratic term, firm 2
    # ---- conduct ----
    θm_1::T     # intermediate-good seller (firm 1)
    θp_1::T     # final-goods firm 1
    θp_2::T     # final-goods firm 2
    θr::T       # retailer
    # ---- ad valorem tax rates ----
    τ_p::T      # on final-goods (manufacturer) sales
    τ_m::T      # on intermediate-good sales
    τ_r::T      # on retail sales
end # Market

"""
    Choices{T}

A candidate set of quantities and prices for the market.

The type is deliberately called `Choices` and not `Equilibrium`: an instance
need not satisfy the equilibrium conditions. Measure how far it is from one
with `EvaluateMarketConditions`.

Fields, in positional-constructor order:
- `q_g1`, `q_g2`: growing/cultivation quantities of firms 1 and 2
- `q_p1`, `q_p2`: processing/manufacturing quantities of firms 1 and 2
- `p_m`: intermediate-good (cultivator) price
- `p_f`: final-good (manufacturer) price
- `p_r`: retail price
"""
mutable struct Choices{T}
    # ---- quantities ----
    q_g1::T
    q_g2::T
    q_p1::T
    q_p2::T
    # ---- prices ----
    p_m::T
    p_f::T
    p_r::T
end # Choices

"""
    EvaluateMarketConditions(choices::Choices, mkt::Market)

Return the 7 equilibrium-condition residuals for `choices` in market `mkt`.
Every residual is zero exactly when the conditions hold.

1. Intermediate-good market clearing: `q_p1 + q_p2 - q_g1 - q_g2`.
2. Consumer demand: `p_r*(1+τ_r) - (q/k)^(-1/r)`, where `q` is the average of
   total processing and total growing.
3. Final-goods firm 1 conduct: `(p_f - mc_1/(1-τ_p))/p_f * ϵ_D - θp_1`.
4. Final-goods firm 2 conduct: the same with `mc_2`, `θp_2`.
5. Buyer cost minimisation: `p_m - mc_g2`.
6. Seller (firm 1) profit maximisation with conduct `θm_1`.
7. Retailer conduct: `(p_r - (1+τ_r)*p_f)/p_r * r - θr`.

Marginal costs are `mc_gi = βg_1i + 2βg_2i*q_gi` and `mc_pi = βp_1i + 2βp_2i*q_pi`.
Firm `i`'s combined marginal cost is `mc_i = mc_gi + mc_pi`.

The result is a `Vector{typeof(q)}`. Each residual is converted to that type.
"""
function EvaluateMarketConditions(choices::Choices, mkt::Market)
    # Average of total processing and total growing; the two coincide once the
    # intermediate market clears.
    q = (choices.q_p1 + choices.q_p2 + choices.q_g1 + choices.q_g2) * 0.5

    # Marginal costs per stage (g = growing, p = processing) and per firm.
    mc_g1 = mkt.βg_11 + 2.0 * mkt.βg_21 * choices.q_g1
    mc_g2 = mkt.βg_12 + 2.0 * mkt.βg_22 * choices.q_g2
    mc_p1 = mkt.βp_11 + 2.0 * mkt.βp_21 * choices.q_p1
    mc_p2 = mkt.βp_12 + 2.0 * mkt.βp_22 * choices.q_p2
    mc_1 = mc_g1 + mc_p1
    mc_2 = mc_g2 + mc_p2

    # Elasticity of demand for wholesale final goods.
    ϵ_D = (1+mkt.τ_r) * mkt.r / (mkt.r - mkt.θr) * choices.p_f / choices.p_r * mkt.r

    # Conduct residual of a final-goods firm with combined marginal cost `mc`
    # and conduct parameter `θ`; τ_p scales the marginal cost.
    final_goods_conduct(mc, θ) = (choices.p_f - mc/(1-mkt.τ_p))/choices.p_f*ϵ_D - θ

    T = typeof(q)
    return T[
        # 1. market clearing
        choices.q_p1 + choices.q_p2 - choices.q_g1 - choices.q_g2,
        # 2. demand for final goods from consumers
        choices.p_r*(1+mkt.τ_r)-(q/mkt.k)^(1/(-mkt.r)),
        # 3. final goods firm 1
        final_goods_conduct(mc_1, mkt.θp_1),
        # 4. final goods firm 2
        final_goods_conduct(mc_2, mkt.θp_2),
        # 5. cost minimization for the buyer
        choices.p_m - mc_g2,
        # 6. profit max for the seller
        choices.p_m-mc_g1/(1-mkt.τ_m) -
            (mkt.θm_1*((choices.p_f*(1-mkt.τ_p)-mc_p1)/(1-mkt.τ_m)-choices.p_m)),
        # 7. conduct condition for retailer
        (choices.p_r - (1+mkt.τ_r)*choices.p_f)/choices.p_r * mkt.r - mkt.θr,
    ]
end # function EvaluateMarketConditions

"""
    CondWrap(x::AbstractVector, m::Market)

Return the sum of squared equilibrium-condition residuals for the choices in `x`.
This is the objective that `SolveEquilibrium` minimises.

`x` holds the 7 choice values in the field order of `Choices`:
`q_g1, q_g2, q_p1, q_p2, p_m, p_f, p_r`.

If evaluating the conditions throws, the error is logged together with `x`
and `m`. The original exception is then rethrown with its backtrace intact.
"""
function CondWrap(x::AbstractVector, m::Market)
    choices = Choices{typeof(x[1])}(x...)
    # Bind the result of the try block directly; a placeholder initialiser
    # would make the variable change type (Vector{Float64} -> Vector{T}).
    residuals = try
        EvaluateMarketConditions(choices, m)
    catch err
        @error "Got $err from $x with $m."
        # `rethrow()` keeps the original backtrace; `throw(err)` would reset it.
        rethrow()
    end
    return sum(abs2, residuals)
end # function CondWrap

"""
    SolveEquilibrium(mkt::Market; start=fill(1.0, 7))

Search for equilibrium choices in `mkt`. The objective is the sum of squared
equilibrium residuals (`CondWrap`), minimised with Optim's `Newton()` method
and forward-mode automatic differentiation.

The search runs in log space: the optimiser moves `x` and the model is
evaluated at `exp.(x)`, which keeps every quantity and price strictly positive.

# Keywords
- `start`: initial point in log space, with 7 entries ordered like the fields of
  `Choices` (`q_g1, q_g2, q_p1, q_p2, p_m, p_f, p_r`). The default of all ones
  starts every choice at `exp(1)`. Pass a previous minimizer to warm-start.

# Returns
The Optim result object. Recover the choices with
`Choices(exp.(Optim.minimizer(result))...)`. Convergence is not checked here.

# Throws
- `ArgumentError` if `start` does not have 7 entries.
"""
function SolveEquilibrium(mkt::Market; start::AbstractVector{<:Real}=fill(1.0, 7))
    length(start) == 7 || throw(ArgumentError(
        "start must have 7 entries (one per Choices field), got $(length(start))"))

    # Objective in log space; exp keeps every choice strictly positive.
    objective(x) = CondWrap(exp.(x), mkt)

    # Work on a fresh Float vector so the caller's `start` is never aliased
    # and integer inputs do not confuse the optimiser.
    x0 = float.(start)

    result = optimize(objective, x0, Newton(); autodiff=:forward)

    return result
end # function SolveEquilibrium

"""
    CalculateCalibrationMoments(x::Vector; debug::Bool=false)

Solve the pre- and post-reform equilibria implied by `x` and return six
calibration moments computed from the pre-reform equilibrium.

# Arguments
- `x = [βg_21, βg_22, βp_21, βp_22, θp, θr]`:
  - the quadratic growing-cost coefficients of firms 1 and 2;
  - the quadratic processing-cost coefficients of firms 1 and 2;
  - the final-goods conduct parameter, shared by both firms;
  - the retailer conduct parameter.

  All other parameters are fixed: `k = 449.4767002`, `r = 1.1689`, zero
  linear cost terms and `θm_1 = 0`. The tax rates are
  `τ_p = τ_m = τ_r = 0.25` pre-reform and `τ_p = τ_m = 0`, `τ_r = 0.37`
  post-reform.

# Keywords
- `debug`: print the solved choices and the optimiser's objective value for both regimes.

# Returns
A length-6 `Vector{typeof(x[1])}` of pre-reform moments:
1. fraction of vertical integration, `1 - |q_p1 - q_g1| / (q_p1 + q_p2)`
2. manufacturer price `p_f`
3. cultivator price `p_m`
4. manufacturer ratio `q_p1 / q_p2`
5. cultivator ratio `q_g1 / q_g2`
6. retail price `p_r`
"""
function CalculateCalibrationMoments(x::Vector; debug::Bool=false)
    # The two regimes differ only in their tax rates.
    make_market(τ_p, τ_m, τ_r) = Market(449.4767002, 1.1689,
                                        0.0, x[1], 0.0, x[2],   # growing costs
                                        0.0, x[3], 0.0, x[4],   # processing costs
                                        0.0, x[5], x[5], x[6],  # θm_1, θp_1, θp_2, θr
                                        τ_p, τ_m, τ_r)
    mkt_prereform = make_market(0.25, 0.25, 0.25)
    mkt_postreform = make_market(0.0, 0.0, 0.37)

    res_pre = SolveEquilibrium(mkt_prereform)
    res_post = SolveEquilibrium(mkt_postreform)

    # The solver works in log space.
    out_pre = Choices(exp.(Optim.minimizer(res_pre))...)
    out_post = Choices(exp.(Optim.minimizer(res_post))...)

    if debug
        println("Calibration moments at $x")
        println("\tPre-reform: $out_pre @ $(Optim.minimum(res_pre))")
        println("\tPost-reform: $out_post @ $(Optim.minimum(res_post))")
    end

    T = typeof(x[1])
    return T[
        # 1: fraction of vertical integration pre-reform
        1 - abs(out_pre.q_p1 - out_pre.q_g1)/(out_pre.q_p1 + out_pre.q_p2),
        # 2: manufacturer price, pre-reform
        out_pre.p_f,
        # 3: cultivator price, pre-reform
        out_pre.p_m,
        # 4: manufacturer ratio, pre-reform; the firms that sell post-reform
        #    relative to the firms that buy post-reform
        out_pre.q_p1 / out_pre.q_p2,
        # 5: cultivator ratio, pre-reform; firm 1 is assumed unconstrained in
        #    planting and firm 2 constrained
        out_pre.q_g1 / out_pre.q_g2,
        # 6: retail price, pre-reform
        out_pre.p_r,
    ]
end # function CalculateCalibrationMoments

"""
    ConsumerSurplus(c::Choices, m::Market)

Return consumer surplus at choices `c`.

Demand has constant elasticity `r`, with inverse demand `P(q) = (q/k)^(-1/r)`
(the demand condition in `EvaluateMarketConditions`). Surplus is the gross
benefit minus what consumers pay:

- gross benefit: `∫₀^q P(s) ds = q * r * (q/k)^(-1/r) / (r - 1)`, finite only for `r > 1`;
- payment: `q * p_r * (1 + τ_r)`, the tax-inclusive retail price;
- quantity: `q = q_p1 + q_p2`.
"""
function ConsumerSurplus(c::Choices, m::Market)
    q = c.q_p1 + c.q_p2
    gross_benefit = q*m.r*(q/m.k)^(-1/m.r)/(m.r-1)
    # Consumers face the retail price inclusive of the retail tax, matching
    # the demand condition p_r*(1+τ_r) == (q/k)^(-1/r).
    expenditure = q*c.p_r*(1+m.τ_r)
    return gross_benefit - expenditure
end

"""
    TotalProfits(c::Choices, m::Market)

Return the sum of firm 1's and firm 2's profits at choices `c`. Retailer profit
is not included.

For firm `i`, profit is revenue minus taxes minus production cost:
- revenue `(q_gi - q_pi)*p_m + q_pi*p_f`: net intermediate-good sales at
  `p_m` (negative for a net buyer) plus final-good sales at `p_f`;
- taxes: `τ_m` on the value of intermediate goods sold, charged only when
  `q_gi > q_pi`, plus `τ_p` on the value of final-good sales;
- cost `βg_1i*q_gi + βg_2i*q_gi^2 + βp_1i*q_pi + βp_2i*q_pi^2`. Its derivatives
  are the marginal costs used in `EvaluateMarketConditions`.
"""
function TotalProfits(c::Choices, m::Market)
    prof1 = _firm_profit(c, m, c.q_g1, c.q_p1, m.βg_11, m.βg_21, m.βp_11, m.βp_21)
    prof2 = _firm_profit(c, m, c.q_g2, c.q_p2, m.βg_12, m.βg_22, m.βp_12, m.βp_22)
    return prof1 + prof2
end

# Profit of one firm that grows `q_g` and processes `q_p`, with linear/quadratic
# growing-cost coefficients `βg_1`/`βg_2` and processing-cost coefficients `βp_1`/`βp_2`.
function _firm_profit(c::Choices, m::Market, q_g, q_p, βg_1, βg_2, βp_1, βp_2)
    revenue = (q_g - q_p)*c.p_m + q_p*c.p_f

    # Only intermediate goods actually sold are taxed. `zero(...)` keeps the
    # result's type independent of which branch is taken.
    excess = q_g - q_p
    sold = max(excess, zero(excess))
    tax = sold*c.p_m*m.τ_m + q_p*c.p_f*m.τ_p

    # Include the linear terms so cost is consistent with the marginal costs
    # βg_1 + 2βg_2*q_g and βp_1 + 2βp_2*q_p.
    cost = βg_1*q_g + βg_2*q_g^2 + βp_1*q_p + βp_2*q_p^2

    return revenue - tax - cost
end

"""
    TotalSurplus(c::Choices, m::Market)

Return `ConsumerSurplus(c, m) + TotalProfits(c, m)`.

NOTE(review): tax revenue collected from the firms and consumers is not added
back here, so this is not a full welfare measure. Confirm that this is intended.
"""
TotalSurplus(c::Choices, m::Market) = ConsumerSurplus(c, m) + TotalProfits(c, m)


end # module Vert